library(tidyverse)
library(tidymodels)
library(modeldata) #dataset cells
# Make the black-and-white theme the default for every ggplot in this script.
theme_set(theme_bw())

# Load the `cells` data set (from modeldata) and work on a copy named `cell`.
data("cells")
cell <- cells
print(cell)
# A tibble: 2,019 x 58
case class angle_ch_1 area_ch_1 avg_inten_ch_1 avg_inten_ch_2 avg_inten_ch_3
<fct> <fct> <dbl> <int> <dbl> <dbl> <dbl>
1 Test PS 143. 185 15.7 4.95 9.55
2 Train PS 134. 819 31.9 207. 69.9
3 Train WS 107. 431 28.0 116. 63.9
4 Train PS 69.2 298 19.5 102. 28.2
5 Test PS 2.89 285 24.3 112. 20.5
6 Test WS 40.7 172 326. 654. 129.
7 Test WS 174. 177 260. 596. 124.
8 Test PS 180. 251 18.3 5.73 17.2
9 Test WS 18.9 495 16.1 89.5 13.7
10 Test WS 153. 384 17.7 89.9 20.4
# ... with 2,009 more rows, and 51 more variables: avg_inten_ch_4 <dbl>,
# convex_hull_area_ratio_ch_1 <dbl>, convex_hull_perim_ratio_ch_1 <dbl>,
# diff_inten_density_ch_1 <dbl>, diff_inten_density_ch_3 <dbl>,
# diff_inten_density_ch_4 <dbl>, entropy_inten_ch_1 <dbl>,
# entropy_inten_ch_3 <dbl>, entropy_inten_ch_4 <dbl>,
# eq_circ_diam_ch_1 <dbl>, eq_ellipse_lwr_ch_1 <dbl>,
# eq_ellipse_oblate_vol_ch_1 <dbl>, eq_ellipse_prolate_vol_ch_1 <dbl>,
# eq_sphere_area_ch_1 <dbl>, eq_sphere_vol_ch_1 <dbl>,
# fiber_align_2_ch_3 <dbl>, fiber_align_2_ch_4 <dbl>,
# fiber_length_ch_1 <dbl>, fiber_width_ch_1 <dbl>, inten_cooc_asm_ch_3 <dbl>,
# inten_cooc_asm_ch_4 <dbl>, inten_cooc_contrast_ch_3 <dbl>,
# inten_cooc_contrast_ch_4 <dbl>, inten_cooc_entropy_ch_3 <dbl>,
# inten_cooc_entropy_ch_4 <dbl>, inten_cooc_max_ch_3 <dbl>,
# inten_cooc_max_ch_4 <dbl>, kurt_inten_ch_1 <dbl>, kurt_inten_ch_3 <dbl>,
# kurt_inten_ch_4 <dbl>, length_ch_1 <dbl>, neighbor_avg_dist_ch_1 <dbl>,
# neighbor_min_dist_ch_1 <dbl>, neighbor_var_dist_ch_1 <dbl>,
# perim_ch_1 <dbl>, shape_bfr_ch_1 <dbl>, shape_lwr_ch_1 <dbl>,
# shape_p_2_a_ch_1 <dbl>, skew_inten_ch_1 <dbl>, skew_inten_ch_3 <dbl>,
# skew_inten_ch_4 <dbl>, spot_fiber_count_ch_3 <int>,
# spot_fiber_count_ch_4 <dbl>, total_inten_ch_1 <int>,
# total_inten_ch_2 <dbl>, total_inten_ch_3 <int>, total_inten_ch_4 <int>,
# var_inten_ch_1 <dbl>, var_inten_ch_3 <dbl>, var_inten_ch_4 <dbl>,
# width_ch_1 <dbl>
# Column-wise overview: one line per variable with its type and first values.
cell %>% glimpse()
Rows: 2,019
Columns: 58
$ case <fct> Test, Train, Train, Train, Test, Test, Te~
$ class <fct> PS, PS, WS, PS, PS, WS, WS, PS, WS, WS, W~
$ angle_ch_1 <dbl> 143.247705, 133.752037, 106.646387, 69.15~
$ area_ch_1 <int> 185, 819, 431, 298, 285, 172, 177, 251, 4~
$ avg_inten_ch_1 <dbl> 15.71186, 31.92327, 28.03883, 19.45614, 2~
$ avg_inten_ch_2 <dbl> 4.954802, 206.878517, 116.315534, 102.294~
$ avg_inten_ch_3 <dbl> 9.548023, 69.916880, 63.941748, 28.217544~
$ avg_inten_ch_4 <dbl> 2.214689, 164.153453, 106.696602, 31.0280~
$ convex_hull_area_ratio_ch_1 <dbl> 1.124509, 1.263158, 1.053310, 1.202625, 1~
$ convex_hull_perim_ratio_ch_1 <dbl> 0.9196827, 0.7970801, 0.9354750, 0.865829~
$ diff_inten_density_ch_1 <dbl> 29.51923, 31.87500, 32.48771, 26.73228, 3~
$ diff_inten_density_ch_3 <dbl> 13.77564, 43.12228, 35.98577, 22.91732, 2~
$ diff_inten_density_ch_4 <dbl> 6.826923, 79.308424, 51.357050, 26.393701~
$ entropy_inten_ch_1 <dbl> 4.969781, 6.087592, 5.883557, 5.420065, 5~
$ entropy_inten_ch_3 <dbl> 4.371017, 6.642761, 6.683000, 5.436732, 5~
$ entropy_inten_ch_4 <dbl> 2.718884, 7.880155, 7.144601, 5.778329, 5~
$ eq_circ_diam_ch_1 <dbl> 15.36954, 32.30558, 23.44892, 19.50279, 1~
$ eq_ellipse_lwr_ch_1 <dbl> 3.060676, 1.558394, 1.375386, 3.391220, 2~
$ eq_ellipse_oblate_vol_ch_1 <dbl> 336.9691, 2232.9055, 802.1945, 724.7143, ~
$ eq_ellipse_prolate_vol_ch_1 <dbl> 110.0963, 1432.8246, 583.2504, 213.7031, ~
$ eq_sphere_area_ch_1 <dbl> 742.1156, 3278.7256, 1727.4104, 1194.9320~
$ eq_sphere_vol_ch_1 <dbl> 1900.996, 17653.525, 6750.985, 3884.084, ~
$ fiber_align_2_ch_3 <dbl> 1.000000, 1.487935, 1.300522, 1.220424, 1~
$ fiber_align_2_ch_4 <dbl> 1.000000, 1.352374, 1.522316, 1.733250, 1~
$ fiber_length_ch_1 <dbl> 26.98132, 64.28230, 21.14115, 43.14112, 3~
$ fiber_width_ch_1 <dbl> 7.410365, 13.167079, 21.141150, 7.404412,~
$ inten_cooc_asm_ch_3 <dbl> 0.011183899, 0.028051061, 0.006862315, 0.~
$ inten_cooc_asm_ch_4 <dbl> 0.050448005, 0.012594975, 0.006141691, 0.~
$ inten_cooc_contrast_ch_3 <dbl> 40.751777, 8.227953, 14.446074, 7.299457,~
$ inten_cooc_contrast_ch_4 <dbl> 13.895439, 6.984046, 16.700843, 13.390884~
$ inten_cooc_entropy_ch_3 <dbl> 7.199458, 6.822138, 7.580100, 6.312641, 6~
$ inten_cooc_entropy_ch_4 <dbl> 5.249744, 7.098988, 7.671478, 7.197026, 5~
$ inten_cooc_max_ch_3 <dbl> 0.07741935, 0.15321477, 0.02835052, 0.162~
$ inten_cooc_max_ch_4 <dbl> 0.17197452, 0.07387141, 0.02319588, 0.077~
$ kurt_inten_ch_1 <dbl> -0.656744087, -0.248769067, -0.293484630,~
$ kurt_inten_ch_3 <dbl> -0.608058268, -0.330783900, 1.051281336, ~
$ kurt_inten_ch_4 <dbl> 0.7258145, -0.2652638, 0.1506140, -0.3472~
$ length_ch_1 <dbl> 26.20779, 47.21855, 28.14303, 37.85957, 3~
$ neighbor_avg_dist_ch_1 <dbl> 370.4543, 174.4442, 158.4774, 206.3344, 2~
$ neighbor_min_dist_ch_1 <dbl> 99.10349, 30.11114, 34.94477, 33.08030, 2~
$ neighbor_var_dist_ch_1 <dbl> 127.96080, 81.38063, 90.43768, 116.89276,~
$ perim_ch_1 <dbl> 68.78338, 154.89876, 84.56460, 101.09107,~
$ shape_bfr_ch_1 <dbl> 0.6651480, 0.5397584, 0.7243116, 0.589162~
$ shape_lwr_ch_1 <dbl> 2.462450, 1.468181, 1.328408, 2.826854, 2~
$ shape_p_2_a_ch_1 <dbl> 1.883006, 2.255810, 1.272193, 2.545840, 2~
$ skew_inten_ch_1 <dbl> 0.45450484, 0.39870467, 0.47248709, 0.881~
$ skew_inten_ch_3 <dbl> 0.46039340, 0.61973079, 0.97137879, 0.999~
$ skew_inten_ch_4 <dbl> 1.2327736, 0.5272631, 0.3247065, 0.604439~
$ spot_fiber_count_ch_3 <int> 1, 4, 2, 4, 1, 1, 0, 2, 1, 1, 1, 0, 0, 2,~
$ spot_fiber_count_ch_4 <dbl> 5, 12, 7, 8, 8, 5, 5, 8, 12, 8, 5, 6, 7, ~
$ total_inten_ch_1 <int> 2781, 24964, 11552, 5545, 6603, 53779, 43~
$ total_inten_ch_2 <dbl> 701, 160998, 47511, 28870, 30306, 107681,~
$ total_inten_ch_3 <int> 1690, 54675, 26344, 8042, 5569, 21234, 20~
$ total_inten_ch_4 <int> 392, 128368, 43959, 8843, 11037, 57231, 4~
$ var_inten_ch_1 <dbl> 12.47468, 18.80923, 17.29564, 13.81897, 1~
$ var_inten_ch_3 <dbl> 7.609035, 56.715352, 37.671053, 30.005643~
$ var_inten_ch_4 <dbl> 2.714100, 118.388139, 49.470524, 24.74953~
$ width_ch_1 <dbl> 10.64297, 32.16126, 21.18553, 13.39283, 1~
# Count the missing values in each column.
cell %>%
  is.na() %>%
  colSums()
case class
0 0
angle_ch_1 area_ch_1
0 0
avg_inten_ch_1 avg_inten_ch_2
0 0
avg_inten_ch_3 avg_inten_ch_4
0 0
convex_hull_area_ratio_ch_1 convex_hull_perim_ratio_ch_1
0 0
diff_inten_density_ch_1 diff_inten_density_ch_3
0 0
diff_inten_density_ch_4 entropy_inten_ch_1
0 0
entropy_inten_ch_3 entropy_inten_ch_4
0 0
eq_circ_diam_ch_1 eq_ellipse_lwr_ch_1
0 0
eq_ellipse_oblate_vol_ch_1 eq_ellipse_prolate_vol_ch_1
0 0
eq_sphere_area_ch_1 eq_sphere_vol_ch_1
0 0
fiber_align_2_ch_3 fiber_align_2_ch_4
0 0
fiber_length_ch_1 fiber_width_ch_1
0 0
inten_cooc_asm_ch_3 inten_cooc_asm_ch_4
0 0
inten_cooc_contrast_ch_3 inten_cooc_contrast_ch_4
0 0
inten_cooc_entropy_ch_3 inten_cooc_entropy_ch_4
0 0
inten_cooc_max_ch_3 inten_cooc_max_ch_4
0 0
kurt_inten_ch_1 kurt_inten_ch_3
0 0
kurt_inten_ch_4 length_ch_1
0 0
neighbor_avg_dist_ch_1 neighbor_min_dist_ch_1
0 0
neighbor_var_dist_ch_1 perim_ch_1
0 0
shape_bfr_ch_1 shape_lwr_ch_1
0 0
shape_p_2_a_ch_1 skew_inten_ch_1
0 0
skew_inten_ch_3 skew_inten_ch_4
0 0
spot_fiber_count_ch_3 spot_fiber_count_ch_4
0 0
total_inten_ch_1 total_inten_ch_2
0 0
total_inten_ch_3 total_inten_ch_4
0 0
var_inten_ch_1 var_inten_ch_3
0 0
var_inten_ch_4 width_ch_1
0 0
# outcome (response) is class
# PS: poorly segmented
# WS: well segmented
# Drop the `case` column; the script makes its own train/test split further
# down.
cell <- select(cell, -case)

# Class balance of the outcome.
count(cell, class)
# A tibble: 2 x 2
class n
<fct> <int>
1 PS 1300
2 WS 719
# 75/25 train/test split, stratified on the outcome `class`.
set.seed(2021)
cell_split <- cell %>%
  initial_split(prop = .75, strata = class)
cell_train <- cell_split %>% training()
cell_test <- cell_split %>% testing()

# 10-fold cross-validation on the training set, used for all tuning below.
set.seed(2021)
cell_folds <- cell_train %>% vfold_cv(v = 10)
# Preprocessing: Yeo-Johnson transform, normalisation, then PCA on the numeric
# predictors. The number of PCA components is left as a tuning parameter.
mlp_nnet_rec <- recipe(class ~ ., data = cell_train) %>%
  step_YeoJohnson(all_numeric_predictors()) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_pca(all_numeric_predictors(), num_comp = tune())

# Single-hidden-layer neural network (nnet engine, tracing switched off) with
# hidden units, penalty and epochs all marked for tuning.
mlp_nnet_model <- mlp(
  hidden_units = tune(),
  penalty = tune(),
  epochs = tune()
) %>%
  set_engine("nnet", trace = 0) %>%
  set_mode("classification")

# Bundle the recipe and the model specification into one workflow.
mlp_nnet_wf <- workflow() %>%
  add_recipe(mlp_nnet_rec) %>%
  add_model(mlp_nnet_model)
# Collect the tunable parameters of the workflow (model + recipe).
# NOTE(review): newer tidymodels releases offer extract_parameter_set_dials()
# and extract_parameter_dials() for these two calls -- confirm against the
# installed package versions before switching.
mlp_param <- parameters(mlp_nnet_wf)

# Default range of `hidden_units`.
pull_dials_object(mlp_param, "hidden_units")
# Hidden Units (quantitative)
Range: [1, 10]
# Default range of `penalty`.
pull_dials_object(mlp_param, "penalty")
Amount of Regularization (quantitative)
Transformer: log-10
Range (transformed scale): [-10, 0]
# Default range of `epochs`.
pull_dials_object(mlp_param, "epochs")
# Epochs (quantitative)
Range: [10, 1000]
# Default range of `num_comp` (number of PCA components).
pull_dials_object(mlp_param, "num_comp")
# Components (quantitative)
Range: [1, 4]
# Hand-built grid: every combination of the listed values (3 x 2 x 2 = 12
# rows). The column names must match the tuning parameter ids exactly, so the
# second column is spelled `penalty` (it was previously misspelled `penaly`).
crox <-
  crossing(
    hidden_units = 1:3,
    penalty = c(0.0, 0.1),
    epochs = c(100, 200)
  )

# Interactive inspection only (opens the data viewer).
crox %>% View()

# Alternative construction with expand_grid():
# expand_grid(
#   hidden_units = 1:3,
#   penalty = c(0.0, 0.1),
#   epochs = c(100, 200)
# )
# Regular grid with two levels per parameter (2^4 = 16 candidates).
mlp_param %>% grid_regular(levels = 2)
# A tibble: 16 x 4
hidden_units penalty epochs num_comp
<int> <dbl> <int> <int>
1 1 0.0000000001 10 1
2 10 0.0000000001 10 1
3 1 1 10 1
4 10 1 10 1
5 1 0.0000000001 1000 1
6 10 0.0000000001 1000 1
7 1 1 1000 1
8 10 1 1000 1
9 1 0.0000000001 10 4
10 10 0.0000000001 10 4
11 1 1 10 4
12 10 1 10 4
13 1 0.0000000001 1000 4
14 10 0.0000000001 1000 4
15 1 1 1000 4
16 10 1 1000 4
# mlp_param %>%
# grid_regular(levels = c(hidden_units = 3, penalty = 2, epochs = 2))
library(ggforce)
# Scatterplot matrix of a 15-point random design.
# `original = FALSE` keeps penalty in log10 units.
set.seed(2021)
grid_random(mlp_param, size = 15, original = FALSE) %>%
  ggplot(aes(x = .panel_x, y = .panel_y)) +
  geom_point() +
  geom_blank() +
  facet_matrix(
    vars(hidden_units, penalty, epochs),
    # Layer 2 (the blank layer) goes on the diagonal panels.
    layer.diag = 2
  ) +
  labs(title = "Random design with 15 candidates")
# Same scatterplot matrix for a 15-point Latin hypercube design
# (penalty again kept in log10 units via `original = FALSE`).
set.seed(2021)
grid_latin_hypercube(mlp_param, size = 15, original = FALSE) %>%
  ggplot(aes(x = .panel_x, y = .panel_y)) +
  geom_point() +
  geom_blank() +
  facet_matrix(
    vars(hidden_units, penalty, epochs),
    # Layer 2 (the blank layer) goes on the diagonal panels.
    layer.diag = 2
  ) +
  labs(title = "Latin Hypercube design with 15 candidates")
# Rebuild the parameter set with custom ranges: epochs in [50, 200] and
# num_comp in [0, 40].
# NOTE(review): num_comp = 0 presumably means "skip PCA" -- confirm against
# the step_pca() documentation.
mlp_param <-
  update(
    parameters(mlp_nnet_wf),
    epochs = epochs(c(50, 200)),
    num_comp = num_comp(c(0, 40))
  )
# Score candidates by area under the ROC curve only.
roc_metric <- metric_set(roc_auc)

# Keep the held-out predictions and the workflow alongside the results.
keep_pred <- control_resamples(save_workflow = TRUE, save_pred = TRUE)

# Grid search over a regular grid: 3 levels for each of the 4 parameters
# (81 candidates), evaluated on the 10 CV folds.
set.seed(2021)
mlp_tuning <- mlp_nnet_wf %>%
  tune_grid(
    resamples = cell_folds,
    grid = grid_regular(mlp_param, levels = 3),
    metrics = roc_metric,
    control = keep_pred
  )
print(mlp_tuning)
# Tuning results
# 10-fold cross-validation
# A tibble: 10 x 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [1362/152~ Fold01 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,312 x ~
2 <split [1362/152~ Fold02 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,312 x ~
3 <split [1362/152~ Fold03 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,312 x ~
4 <split [1362/152~ Fold04 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,312 x ~
5 <split [1363/151~ Fold05 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,231 x ~
6 <split [1363/151~ Fold06 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,231 x ~
7 <split [1363/151~ Fold07 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,231 x ~
8 <split [1363/151~ Fold08 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,231 x ~
9 <split [1363/151~ Fold09 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,231 x ~
10 <split [1363/151~ Fold10 <tibble [81 x 8~ <tibble [0 x 1~ <tibble [12,231 x ~
# Resample-averaged ROC AUC for every candidate.
collect_metrics(mlp_tuning)
# A tibble: 81 x 10
hidden_units penalty epochs num_comp .metric .estimator mean n std_err
<int> <dbl> <int> <int> <chr> <chr> <dbl> <int> <dbl>
1 1 1e-10 50 0 roc_auc binary 0.871 10 0.0121
2 5 1e-10 50 0 roc_auc binary 0.871 10 0.00573
3 10 1e-10 50 0 roc_auc binary 0.842 10 0.0102
4 1 1e- 5 50 0 roc_auc binary 0.874 10 0.00707
5 5 1e- 5 50 0 roc_auc binary 0.867 10 0.00909
6 10 1e- 5 50 0 roc_auc binary 0.849 10 0.00906
7 1 1e+ 0 50 0 roc_auc binary 0.893 10 0.00535
8 5 1e+ 0 50 0 roc_auc binary 0.896 10 0.00632
9 10 1e+ 0 50 0 roc_auc binary 0.890 10 0.00438
10 1 1e-10 125 0 roc_auc binary 0.821 10 0.0112
# ... with 71 more rows, and 1 more variable: .config <chr>
# Held-out predictions for every candidate and fold (saved via `keep_pred`).
collect_predictions(mlp_tuning)
# A tibble: 122,634 x 10
id .pred_PS .pred_WS .row num_comp hidden_units penalty epochs class
<chr> <dbl> <dbl> <int> <int> <int> <dbl> <int> <fct>
1 Fold01 0.721 0.279 15 0 1 1e-10 50 PS
2 Fold01 0.621 0.379 27 0 1 1e-10 50 PS
3 Fold01 0.384 0.616 29 0 1 1e-10 50 PS
4 Fold01 0.721 0.279 35 0 1 1e-10 50 PS
5 Fold01 0.380 0.620 50 0 1 1e-10 50 PS
6 Fold01 0.719 0.281 52 0 1 1e-10 50 PS
7 Fold01 0.721 0.279 55 0 1 1e-10 50 PS
8 Fold01 0.381 0.619 57 0 1 1e-10 50 PS
9 Fold01 0.380 0.620 58 0 1 1e-10 50 PS
10 Fold01 0.721 0.279 60 0 1 1e-10 50 PS
# ... with 122,624 more rows, and 1 more variable: .config <chr>
# Top candidates ranked by mean ROC AUC.
show_best(mlp_tuning)
# A tibble: 5 x 10
hidden_units penalty epochs num_comp .metric .estimator mean n std_err
<int> <dbl> <int> <int> <chr> <chr> <dbl> <int> <dbl>
1 5 1 50 0 roc_auc binary 0.896 10 0.00632
2 5 1 125 40 roc_auc binary 0.895 10 0.00762
3 5 1 125 0 roc_auc binary 0.895 10 0.00623
4 5 1 50 20 roc_auc binary 0.895 10 0.00653
5 5 1 200 40 roc_auc binary 0.894 10 0.00569
# ... with 1 more variable: .config <chr>
# Parameter combination with the highest mean ROC AUC.
select_best(mlp_tuning)
# A tibble: 1 x 5
hidden_units penalty epochs num_comp .config
<int> <dbl> <int> <int> <chr>
1 5 1 50 0 Preprocessor1_Model08
# Performance profile across the regular grid, with the legend above the plot.
mlp_tuning %>%
  autoplot() +
  theme(legend.position = "top")
# Number of physical CPU cores reported for this machine.
parallel::detectCores(logical = FALSE)
[1] 8
# Number of logical CPU cores reported for this machine.
parallel::detectCores(logical = TRUE)
[1] 16
# Second search: let tune_grid() generate 20 candidate combinations within the
# ranges defined in `mlp_param`, instead of supplying an explicit grid.
set.seed(2021)
mlp_tuning2 <- mlp_nnet_wf %>%
  tune_grid(
    resamples = cell_folds,
    grid = 20,
    param_info = mlp_param,
    metrics = roc_metric,
    control = keep_pred
  )
print(mlp_tuning2)
# Tuning results
# 10-fold cross-validation
# A tibble: 10 x 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [1362/152~ Fold01 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,040 x 9~
2 <split [1362/152~ Fold02 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,040 x 9~
3 <split [1362/152~ Fold03 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,040 x 9~
4 <split [1362/152~ Fold04 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,040 x 9~
5 <split [1363/151~ Fold05 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,020 x 9~
6 <split [1363/151~ Fold06 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,020 x 9~
7 <split [1363/151~ Fold07 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,020 x 9~
8 <split [1363/151~ Fold08 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,020 x 9~
9 <split [1363/151~ Fold09 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,020 x 9~
10 <split [1363/151~ Fold10 <tibble [20 x 8~ <tibble [0 x 1~ <tibble [3,020 x 9~
# Held-out predictions for the 20 generated candidates.
collect_predictions(mlp_tuning2)
# A tibble: 30,280 x 10
id .pred_PS .pred_WS .row num_comp hidden_units penalty epochs class
<chr> <dbl> <dbl> <int> <int> <int> <dbl> <int> <fct>
1 Fold01 0.720 0.280 15 14 6 0.00106 112 PS
2 Fold01 0.341 0.659 27 14 6 0.00106 112 PS
3 Fold01 0.716 0.284 29 14 6 0.00106 112 PS
4 Fold01 0.635 0.365 35 14 6 0.00106 112 PS
5 Fold01 0.349 0.651 50 14 6 0.00106 112 PS
6 Fold01 0.664 0.336 52 14 6 0.00106 112 PS
7 Fold01 0.720 0.280 55 14 6 0.00106 112 PS
8 Fold01 0.509 0.491 57 14 6 0.00106 112 PS
9 Fold01 0.341 0.659 58 14 6 0.00106 112 PS
10 Fold01 0.730 0.270 60 14 6 0.00106 112 PS
# ... with 30,270 more rows, and 1 more variable: .config <chr>
# Resample-averaged ROC AUC for the 20 generated candidates.
collect_metrics(mlp_tuning2)
# A tibble: 20 x 10
hidden_units penalty epochs num_comp .metric .estimator mean n std_err
<int> <dbl> <int> <int> <chr> <chr> <dbl> <int> <dbl>
1 6 1.06e- 3 112 14 roc_auc binary 0.854 10 0.00960
2 6 4.31e- 8 135 39 roc_auc binary 0.868 10 0.00675
3 6 4.80e-10 147 15 roc_auc binary 0.855 10 0.0111
4 1 1.37e- 1 77 28 roc_auc binary 0.876 10 0.00559
5 5 4.81e- 6 105 18 roc_auc binary 0.846 10 0.00993
6 9 8.70e- 2 164 18 roc_auc binary 0.851 10 0.0101
7 5 1.43e- 4 184 33 roc_auc binary 0.856 10 0.00943
8 9 3.51e- 5 157 5 roc_auc binary 0.860 10 0.00631
9 7 5.82e- 4 70 7 roc_auc binary 0.866 10 0.00726
10 2 2.14e- 2 176 37 roc_auc binary 0.884 10 0.00694
11 10 1.59e-10 55 35 roc_auc binary 0.836 10 0.0118
12 4 2.60e- 6 100 3 roc_auc binary 0.870 10 0.00577
13 2 6.04e- 9 81 25 roc_auc binary 0.839 10 0.00736
14 2 1.04e- 8 189 31 roc_auc binary 0.853 10 0.00686
15 4 1.74e- 5 62 26 roc_auc binary 0.855 10 0.00812
16 3 1.95e- 9 199 22 roc_auc binary 0.860 10 0.00562
17 9 3.32e- 7 118 12 roc_auc binary 0.833 10 0.00749
18 8 3.96e- 3 128 1 roc_auc binary 0.769 10 0.00806
19 7 1.15e- 7 148 9 roc_auc binary 0.849 10 0.00965
20 3 6.59e- 1 93 20 roc_auc binary 0.892 10 0.00520
# ... with 1 more variable: .config <chr>
# Ten best candidates of the second search.
show_best(mlp_tuning2, n = 10)
# A tibble: 10 x 10
hidden_units penalty epochs num_comp .metric .estimator mean n std_err
<int> <dbl> <int> <int> <chr> <chr> <dbl> <int> <dbl>
1 3 6.59e- 1 93 20 roc_auc binary 0.892 10 0.00520
2 2 2.14e- 2 176 37 roc_auc binary 0.884 10 0.00694
3 1 1.37e- 1 77 28 roc_auc binary 0.876 10 0.00559
4 4 2.60e- 6 100 3 roc_auc binary 0.870 10 0.00577
5 6 4.31e- 8 135 39 roc_auc binary 0.868 10 0.00675
6 7 5.82e- 4 70 7 roc_auc binary 0.866 10 0.00726
7 3 1.95e- 9 199 22 roc_auc binary 0.860 10 0.00562
8 9 3.51e- 5 157 5 roc_auc binary 0.860 10 0.00631
9 5 1.43e- 4 184 33 roc_auc binary 0.856 10 0.00943
10 6 4.80e-10 147 15 roc_auc binary 0.855 10 0.0111
# ... with 1 more variable: .config <chr>
# Best candidate of the second search by ROC AUC. The argument is `metric`;
# `.metric` is not an argument of select_best().
mlp_tuning2 %>% select_best(metric = "roc_auc")
# A tibble: 1 x 5
hidden_units penalty epochs num_comp .config
<int> <dbl> <int> <int> <chr>
1 3 0.659 93 20 Preprocessor19_Model1
# Marginal performance plots for the second search.
mlp_tuning2 %>% autoplot()
# Finalizing the model. I choose mlp_tuning (the regular-grid search) because its best roc_auc is slightly higher (0.896 vs. 0.892 for mlp_tuning2).
# Take the best ROC AUC candidate from the regular-grid search and plug its
# values into the workflow in place of the tune() placeholders.
best_mlp_model <- select_best(mlp_tuning, metric = "roc_auc")
last_mlp_wf <- finalize_workflow(mlp_nnet_wf, best_mlp_model)
print(last_mlp_wf)
== Workflow ====================================================================
Preprocessor: Recipe
Model: mlp()
-- Preprocessor ----------------------------------------------------------------
3 Recipe Steps
* step_YeoJohnson()
* step_normalize()
* step_pca()
-- Model -----------------------------------------------------------------------
Single Layer Neural Network Specification (classification)
Main Arguments:
hidden_units = 5
penalty = 1
epochs = 50
Engine-Specific Arguments:
trace = 0
Computational engine: nnet
# Final fit of the finalized workflow on the initial train/test split.
mlp_fit <- last_mlp_wf %>%
  last_fit(split = cell_split)
print(mlp_fit)
# Resampling results
# Manual resampling
# A tibble: 1 x 6
splits id .metrics .notes .predictions .workflow
<list> <chr> <list> <list> <list> <list>
1 <split [1514~ train/test ~ <tibble [2 x~ <tibble [0~ <tibble [505 x~ <workflo~
# Metrics of the final fit.
collect_metrics(mlp_fit)
# A tibble: 2 x 4
.metric .estimator .estimate .config
<chr> <chr> <dbl> <chr>
1 accuracy binary 0.798 Preprocessor1_Model1
2 roc_auc binary 0.879 Preprocessor1_Model1
# Predictions of the final fit (class probabilities and hard class).
mlp_pred <- collect_predictions(mlp_fit)
print(mlp_pred)
# A tibble: 505 x 7
id .pred_PS .pred_WS .row .pred_class class .config
<chr> <dbl> <dbl> <int> <fct> <fct> <chr>
1 train/test split 0.393 0.607 3 WS WS Preprocessor1_Mod~
2 train/test split 0.621 0.379 9 PS WS Preprocessor1_Mod~
3 train/test split 0.716 0.284 10 PS WS Preprocessor1_Mod~
4 train/test split 0.308 0.692 11 WS WS Preprocessor1_Mod~
5 train/test split 0.726 0.274 13 PS PS Preprocessor1_Mod~
6 train/test split 0.401 0.599 16 WS PS Preprocessor1_Mod~
7 train/test split 0.460 0.540 22 WS PS Preprocessor1_Mod~
8 train/test split 0.725 0.275 24 PS PS Preprocessor1_Mod~
9 train/test split 0.583 0.417 30 PS WS Preprocessor1_Mod~
10 train/test split 0.326 0.674 35 WS PS Preprocessor1_Mod~
# ... with 495 more rows
# Confusion matrix of predicted vs. true class for the final fit.
conf_mat(mlp_pred, truth = class, estimate = .pred_class)
Truth
Prediction PS WS
PS 273 50
WS 52 130
# ROC AUC of the final fit. Class-probability metrics take the probability
# column(s) through `...`, not through an `estimate` argument, so `.pred_PS`
# is passed unnamed.
# NOTE(review): `.pred_PS` assumes PS is the event (first factor) level of
# `class` -- confirm.
mlp_pred %>%
  roc_auc(
    truth = class,
    .pred_PS
  )
# A tibble: 1 x 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 roc_auc binary 0.879
# ROC curve of the final fit. As with roc_auc(), the probability column is
# supplied through `...` (unnamed) rather than via `estimate =`.
mlp_pred %>%
  roc_curve(
    truth = class,
    .pred_PS
  ) %>%
  autoplot()
library(finetune)
# Racing search (ANOVA-based elimination) over 20 generated candidates within
# the ranges in `mlp_param`. `verbose_elim = TRUE` (spelled out rather than the
# reassignable shorthand `T`) turns on the elimination log.
set.seed(2021)
mlp_race <-
  tune_race_anova(
    mlp_nnet_wf,
    resamples = cell_folds,
    grid = 20,
    param_info = mlp_param,
    metrics = roc_metric,
    control = control_race(verbose_elim = TRUE)
  )
mlp_race
# Tuning results
# 10-fold cross-validation
# A tibble: 10 x 5
splits id .order .metrics .notes
<list> <chr> <int> <list> <list>
1 <split [1362/152]> Fold01 3 <tibble [20 x 8]> <tibble [0 x 1]>
2 <split [1362/152]> Fold04 2 <tibble [20 x 8]> <tibble [0 x 1]>
3 <split [1363/151]> Fold08 1 <tibble [20 x 8]> <tibble [0 x 1]>
4 <split [1363/151]> Fold05 4 <tibble [13 x 8]> <tibble [0 x 1]>
5 <split [1363/151]> Fold07 5 <tibble [4 x 8]> <tibble [0 x 1]>
6 <split [1363/151]> Fold06 6 <tibble [3 x 8]> <tibble [0 x 1]>
7 <split [1362/152]> Fold03 7 <tibble [3 x 8]> <tibble [0 x 1]>
8 <split [1362/152]> Fold02 8 <tibble [3 x 8]> <tibble [0 x 1]>
9 <split [1363/151]> Fold09 9 <tibble [3 x 8]> <tibble [0 x 1]>
10 <split [1363/151]> Fold10 10 <tibble [3 x 8]> <tibble [0 x 1]>
# Mean ROC AUC per candidate; `n` is the number of folds each one was
# evaluated on before being eliminated (or surviving all 10).
collect_metrics(mlp_race)
# A tibble: 20 x 10
hidden_units penalty epochs num_comp .metric .estimator mean n std_err
<int> <dbl> <int> <int> <chr> <chr> <dbl> <int> <dbl>
1 1 2.07e- 9 101 15 roc_auc binary 0.840 3 0.0215
2 5 6.15e- 9 147 3 roc_auc binary 0.843 3 0.00837
3 8 1.29e- 7 114 13 roc_auc binary 0.848 4 0.00970
4 9 5.80e- 1 156 34 roc_auc binary 0.866 5 0.00955
5 4 1.39e-10 95 34 roc_auc binary 0.853 4 0.00626
6 10 3.34e- 5 122 0 roc_auc binary 0.823 3 0.0150
7 3 3.09e- 6 69 5 roc_auc binary 0.878 10 0.00650
8 7 5.63e- 2 148 19 roc_auc binary 0.854 4 0.00899
9 3 4.81e- 3 178 36 roc_auc binary 0.835 3 0.0237
10 7 8.50e- 4 63 29 roc_auc binary 0.800 3 0.0234
11 8 3.87e- 6 138 39 roc_auc binary 0.847 4 0.0120
12 8 1.49e- 1 200 10 roc_auc binary 0.873 10 0.00897
13 6 6.68e- 8 105 22 roc_auc binary 0.844 4 0.00980
14 2 1.15e- 2 128 6 roc_auc binary 0.882 10 0.00318
15 9 3.83e- 7 84 23 roc_auc binary 0.835 3 0.0344
16 2 2.50e- 8 56 25 roc_auc binary 0.850 4 0.00547
17 4 5.57e-10 192 18 roc_auc binary 0.827 3 0.0206
18 2 1.08e- 4 176 27 roc_auc binary 0.849 4 0.00941
19 5 3.03e- 3 169 31 roc_auc binary 0.849 4 0.0191
20 6 1.68e- 5 74 11 roc_auc binary 0.854 4 0.0169
# ... with 1 more variable: .config <chr>
# Best candidates of the racing search by ROC AUC.
show_best(mlp_race, metric = "roc_auc")
# A tibble: 3 x 10
hidden_units penalty epochs num_comp .metric .estimator mean n std_err
<int> <dbl> <int> <int> <chr> <chr> <dbl> <int> <dbl>
1 2 0.0115 128 6 roc_auc binary 0.882 10 0.00318
2 3 0.00000309 69 5 roc_auc binary 0.878 10 0.00650
3 8 0.149 200 10 roc_auc binary 0.873 10 0.00897
# ... with 1 more variable: .config <chr>
# Finalize the workflow with the racing winner, fit it on the initial
# train/test split, and show the resulting predictions.
best_race <- select_best(mlp_race, metric = "roc_auc")
last_race_wf <- finalize_workflow(mlp_nnet_wf, best_race)
race_fit <- last_fit(last_race_wf, split = cell_split)
collect_predictions(race_fit)
# A tibble: 505 x 7
id .pred_PS .pred_WS .row .pred_class class .config
<chr> <dbl> <dbl> <int> <fct> <fct> <chr>
1 train/test split 0.441 0.559 3 WS WS Preprocessor1_Mod~
2 train/test split 0.624 0.376 9 PS WS Preprocessor1_Mod~
3 train/test split 0.707 0.293 10 PS WS Preprocessor1_Mod~
4 train/test split 0.328 0.672 11 WS WS Preprocessor1_Mod~
5 train/test split 0.728 0.272 13 PS PS Preprocessor1_Mod~
6 train/test split 0.427 0.573 16 WS PS Preprocessor1_Mod~
7 train/test split 0.477 0.523 22 WS PS Preprocessor1_Mod~
8 train/test split 0.730 0.270 24 PS PS Preprocessor1_Mod~
9 train/test split 0.596 0.404 30 PS WS Preprocessor1_Mod~
10 train/test split 0.323 0.677 35 WS PS Preprocessor1_Mod~
# ... with 495 more rows
# Metrics of the racing-selected model.
collect_metrics(race_fit)
# A tibble: 2 x 4
.metric .estimator .estimate .config
<chr> <chr> <dbl> <chr>
1 accuracy binary 0.804 Preprocessor1_Model1
2 roc_auc binary 0.879 Preprocessor1_Model1