library(tidymodels)
library(ISLR)
theme_set(theme_bw())

# Simulated training data: 40 points in two alternating classes ("-1", "1").
# Points of class "1" are moved by +1.5 along both predictors.
set.seed(2021)
sim_data <- tibble(
  x1 = rnorm(40),
  x2 = rnorm(40),
  y = factor(rep(c(-1, 1), 20))
) %>%
  mutate(across(c(x1, x2), ~ ifelse(y == 1, .x + 1.5, .x)))
# Scatter plot of the training data, coloured by class.
sim_data %>%
  ggplot(aes(x = x1, y = x2, colour = y)) +
  geom_point()
# Linear support vector classifier, expressed as a degree-1 polynomial kernel
# and fitted with kernlab.
# `scaled = FALSE` switches off ksvm()'s default standardisation of the
# predictors. The ksvm() argument is `scaled`; `scale` is the e1071::svm()
# spelling and is not recognised by ksvm().
# The "scale" listed among the printed kernel hyperparameters is the
# polynomial kernel's own parameter, not feature scaling.
svm_linear_model <-
  svm_poly(degree = 1) %>%
  set_mode("classification") %>%
  set_engine("kernlab", scaled = FALSE)
# Fit the linear SVM with a large cost (C = 10), so margin violations are
# penalised heavily. Then print the fitted model summary.
svm_linear_fit <- fit(
  set_args(svm_linear_model, cost = 10),
  formula = y ~ .,
  data = sim_data
)
svm_linear_fit
parsnip model object
Fit time: 560ms
Support Vector Machine object of class "ksvm"
SV type: C-svc (classification)
parameter : cost C = 10
Polynomial kernel function.
Hyperparameters : degree = 1 scale = 1 offset = 1
Number of Support Vectors : 15
Objective Function Value : -124.6593
Training error : 0.15
Probability model included.
# kernlab supplies the plot() method for fitted "ksvm" objects.
library(kernlab)
# Plot the underlying kernlab fit (the C = 10 model).
plot(extract_fit_engine(svm_linear_fit))
# Refit with a small cost (C = 0.1): violations are cheap, so the margin
# widens. Then print the fitted model summary.
svm_linear_fit <- fit(
  set_args(svm_linear_model, cost = 0.1),
  formula = y ~ .,
  data = sim_data
)
svm_linear_fit
parsnip model object
Fit time: 21ms
Support Vector Machine object of class "ksvm"
SV type: C-svc (classification)
parameter : cost C = 0.1
Polynomial kernel function.
Hyperparameters : degree = 1 scale = 1 offset = 1
Number of Support Vectors : 27
Objective Function Value : -2.0266
Training error : 0.125
Probability model included.
# Plot the underlying kernlab fit of the low-cost (C = 0.1) model.
plot(extract_fit_engine(svm_linear_fit))
Because the cost parameter is now smaller (0.1 instead of 10), the margin is wider, so more observations lie on or inside it; we therefore obtain a larger number of support vectors.
# Workflow combining the formula y ~ . with the linear SVM. The cost is left
# as a tune() placeholder to be chosen by cross-validation.
linear_svm_wf <-
  workflow() %>%
  add_model(set_args(svm_linear_model, cost = tune())) %>%
  add_formula(y ~ .)
# Class-stratified cross-validation folds (vfold_cv defaults) and a regular
# grid of 10 cost values. Every candidate cost is evaluated on every fold.
set.seed(2021)
sim_data_fold <- sim_data %>%
  vfold_cv(strata = y)
param_grid <- cost() %>%
  grid_regular(levels = 10)
tune_res <- linear_svm_wf %>%
  tune_grid(resamples = sim_data_fold, grid = param_grid)
# Resampled performance per cost value, first as a plot and then as a table.
tune_res %>% autoplot()
collect_metrics(tune_res)
# A tibble: 20 x 7
cost .metric .estimator mean n std_err .config
<dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
1 0.000977 accuracy binary 0.8 10 0.0624 Preprocessor1_Model01
2 0.000977 roc_auc binary 0.925 10 0.0382 Preprocessor1_Model01
3 0.00310 accuracy binary 0.8 10 0.0624 Preprocessor1_Model02
4 0.00310 roc_auc binary 0.925 10 0.0382 Preprocessor1_Model02
5 0.00984 accuracy binary 0.8 10 0.0624 Preprocessor1_Model03
6 0.00984 roc_auc binary 0.925 10 0.0382 Preprocessor1_Model03
7 0.0312 accuracy binary 0.825 10 0.0534 Preprocessor1_Model04
8 0.0312 roc_auc binary 0.925 10 0.0382 Preprocessor1_Model04
9 0.0992 accuracy binary 0.85 10 0.0408 Preprocessor1_Model05
10 0.0992 roc_auc binary 0.925 10 0.0382 Preprocessor1_Model05
11 0.315 accuracy binary 0.85 10 0.0553 Preprocessor1_Model06
12 0.315 roc_auc binary 0.925 10 0.0382 Preprocessor1_Model06
13 1 accuracy binary 0.85 10 0.0553 Preprocessor1_Model07
14 1 roc_auc binary 0.925 10 0.0382 Preprocessor1_Model07
15 3.17 accuracy binary 0.85 10 0.0553 Preprocessor1_Model08
16 3.17 roc_auc binary 0.9 10 0.0553 Preprocessor1_Model08
17 10.1 accuracy binary 0.85 10 0.0553 Preprocessor1_Model09
18 10.1 roc_auc binary 0.9 10 0.0553 Preprocessor1_Model09
19 32 accuracy binary 0.85 10 0.0553 Preprocessor1_Model10
20 32 roc_auc binary 0.9 10 0.0553 Preprocessor1_Model10
show_best(tune_res, metric = "accuracy")  # top candidates by mean accuracy
# A tibble: 5 x 7
cost .metric .estimator mean n std_err .config
<dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
1 0.0992 accuracy binary 0.85 10 0.0408 Preprocessor1_Model05
2 0.315 accuracy binary 0.85 10 0.0553 Preprocessor1_Model06
3 1 accuracy binary 0.85 10 0.0553 Preprocessor1_Model07
4 3.17 accuracy binary 0.85 10 0.0553 Preprocessor1_Model08
5 10.1 accuracy binary 0.85 10 0.0553 Preprocessor1_Model09
# Plug the best cost (by accuracy) into the workflow and refit on all of
# the training data.
tune_best <- select_best(tune_res, metric = "accuracy")
svm_final_wf <- finalize_workflow(linear_svm_wf, tune_best)
linear_svm_fit <- svm_final_wf %>%
  fit(data = sim_data)
# Independent test set generated the same way as the training data:
# 20 points, with class "1" moved by +1.5 along both predictors.
set.seed(2)
sim_data_test <- tibble(
  x1 = rnorm(20),
  x2 = rnorm(20),
  y = factor(rep(c(-1, 1), 10))
) %>%
  mutate(across(c(x1, x2), ~ ifelse(y == 1, .x + 1.5, .x)))
# Attach predictions for the test set, then cross-tabulate them against the
# true classes.
pred <- linear_svm_fit %>%
  augment(new_data = sim_data_test)
conf_mat(pred, truth = y, estimate = .pred_class)
Truth
Prediction -1 1
-1 7 2
1 3 8
accuracy(pred, truth = y, estimate = .pred_class)  # test-set accuracy
# A tibble: 1 x 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 accuracy binary 0.75
# Non-linear problem: 200 points. Rows 1-100 are centred at (2, 2),
# rows 101-150 at (-2, -2) and rows 151-200 at (0, 0).
# The first 150 rows are class 1; the last 50 are class 2.
set.seed(1)
sim_data2 <- tibble(
  x1 = rnorm(200),
  x2 = rnorm(200),
  y = factor(rep(c(1, 2), times = c(150, 50)))
) %>%
  mutate(across(c(x1, x2), ~ .x + rep(c(2, -2, 0), times = c(100, 50, 50))))
# Scatter plot of the non-linear training data, coloured by class.
ggplot(sim_data2, aes(x = x1, y = x2, colour = y)) +
  geom_point()
# Radial-basis SVM classifier with parsnip's default cost and rbf_sigma,
# fitted on the non-linear data with kernlab and then plotted.
# NOTE(review): with rbf_sigma unset, kernlab presumably estimates sigma from
# a random sample, and the probability model is also fitted internally. The
# fit may therefore depend on the current RNG state. Consider set.seed()
# here if this chunk must be reproducible on its own.
svm_rbf_spec <- svm_rbf(mode = "classification") %>%
  set_engine("kernlab")
svm_rbf_fit <- fit(svm_rbf_spec, y ~ ., data = sim_data2)
plot(extract_fit_engine(svm_rbf_fit))
# Test set drawn from the same three-cluster design as sim_data2.
set.seed(2)
sim_data2_test <- tibble(
  x1 = rnorm(200),
  x2 = rnorm(200),
  y = factor(rep(c(1, 2), times = c(150, 50)))
) %>%
  mutate(across(c(x1, x2), ~ .x + rep(c(2, -2, 0), times = c(100, 50, 50))))
# Confusion matrix of the RBF SVM's class predictions on the test set.
svm_rbf_fit %>%
  augment(new_data = sim_data2_test) %>%
  conf_mat(truth = y, estimate = .pred_class)
Truth
Prediction 1 2
1 137 7
2 13 43
# ROC coordinates (one row per threshold) for the RBF SVM on the test set.
# roc_curve() has no `estimate` argument: the probability column for the
# event class is given through `...`. With yardstick's default event_level
# ("first"), the event is level "1" of y, so .pred_1 is the matching column.
augment(svm_rbf_fit, new_data = sim_data2_test) %>%
  roc_curve(truth = y, .pred_1)
# A tibble: 202 x 3
.threshold specificity sensitivity
<dbl> <dbl> <dbl>
1 -Inf 0 1
2 0.104 0 1
3 0.113 0.02 1
4 0.114 0.04 1
5 0.115 0.06 1
6 0.117 0.08 1
7 0.118 0.1 1
8 0.119 0.12 1
9 0.124 0.14 1
10 0.124 0.14 0.993
# ... with 192 more rows
# Plot the ROC curve and compute the area under it for the RBF SVM on the
# test set. As with all yardstick probability metrics, the class-probability
# column (.pred_1, for event level "1") is passed through `...`; these
# functions have no `estimate` argument.
augment(svm_rbf_fit, new_data = sim_data2_test) %>%
  roc_curve(truth = y, .pred_1) %>%
  autoplot()
augment(svm_rbf_fit, new_data = sim_data2_test) %>%
  roc_auc(truth = y, .pred_1)
# A tibble: 1 x 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 roc_auc binary 0.925
An Introduction to Statistical Learning
– END
utils::sessionInfo()  # record R and package versions used for this render
R version 4.1.0 (2021-05-18)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19043)
Matrix products: default
locale:
[1] LC_COLLATE=Spanish_Mexico.1252 LC_CTYPE=Spanish_Mexico.1252
[3] LC_MONETARY=Spanish_Mexico.1252 LC_NUMERIC=C
[5] LC_TIME=Spanish_Mexico.1252
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] vctrs_0.3.8 rlang_0.4.11 kernlab_0.9-29 ISLR_1.2
[5] yardstick_0.0.8 workflowsets_0.0.2 workflows_0.2.3 tune_0.1.6
[9] tidyr_1.1.3 tibble_3.1.3 rsample_0.1.0 recipes_0.1.16
[13] purrr_0.3.4 parsnip_0.1.7 modeldata_0.1.1 infer_0.5.4
[17] ggplot2_3.3.5 dplyr_1.0.7 dials_0.0.9 scales_1.1.1
[21] broom_0.7.8 tidymodels_0.1.3
loaded via a namespace (and not attached):
[1] sass_0.4.0 jsonlite_1.7.2 splines_4.1.0 foreach_1.5.1
[5] prodlim_2019.11.13 bslib_0.2.5.1 assertthat_0.2.1 highr_0.9
[9] GPfit_1.0-8 yaml_2.2.1 globals_0.14.0 ipred_0.9-11
[13] pillar_1.6.2 backports_1.2.1 lattice_0.20-44 glue_1.4.2
[17] pROC_1.17.0.1 digest_0.6.27 hardhat_0.1.6 colorspace_2.0-2
[21] plyr_1.8.6 htmltools_0.5.1.1 Matrix_1.3-3 timeDate_3043.102
[25] pkgconfig_2.0.3 lhs_1.1.1 DiceDesign_1.9 listenv_0.8.0
[29] gower_0.2.2 lava_1.6.9 farver_2.1.0 generics_0.1.0
[33] ellipsis_0.3.2 withr_2.4.2 furrr_0.2.3 nnet_7.3-16
[37] cli_3.0.0 survival_3.2-11 magrittr_2.0.1 crayon_1.4.1
[41] evaluate_0.14 future_1.21.0 fansi_0.5.0 parallelly_1.26.1
[45] MASS_7.3-54 class_7.3-19 prettyunits_1.1.1 tools_4.1.0
[49] lifecycle_1.0.0 stringr_1.4.0 munsell_0.5.0 compiler_4.1.0
[53] jquerylib_0.1.4 grid_4.1.0 rstudioapi_0.13 iterators_1.0.13
[57] labeling_0.4.2 rmarkdown_2.9 gtable_0.3.0 codetools_0.2-18
[61] DBI_1.1.1 R6_2.5.0 lubridate_1.7.10 knitr_1.33
[65] utf8_1.2.2 stringi_1.6.2 parallel_4.1.0 Rcpp_1.0.7
[69] rpart_4.1-15 tidyselect_1.1.1 xfun_0.24