An Introduction to Statistical Learning (2nd ed.)

Chapter 09

Support Vector Machines

library(tidymodels)
library(ISLR)

# Make the black-and-white ggplot2 theme the default for every plot below.
theme_set(theme_bw())

Support Vector Classifier

# Simulate 40 two-dimensional observations, alternating between the classes
# -1 and 1, then move the class-1 points by +1.5 on both axes so the two
# classes are only partly separated.
set.seed(2021)

sim_data <- tibble(
  x1 = rnorm(40),
  x2 = rnorm(40),
  y = factor(rep(c(-1, 1), times = 20))
) %>%
  mutate(across(c(x1, x2), ~ ifelse(y == 1, .x + 1.5, .x)))

# Scatter plot of the simulated points, coloured by class.
sim_data %>%
  ggplot(aes(x = x1, y = x2, color = y)) +
  geom_point()

# Linear support vector classifier: a degree-1 polynomial kernel fitted through
# the kernlab engine.
# NOTE(review): kernlab::ksvm() documents its scaling switch as `scaled`;
# confirm that `scale` really reaches it and turns predictor scaling off.
svm_linear_model <-
  svm_poly(degree = 1) %>%
  set_mode("classification") %>%
  set_engine("kernlab", scale = FALSE)

# First fit: cost = 10 on the simulated training data.
svm_linear_fit <-
  svm_linear_model %>%
  set_args(cost = 10) %>%
  fit(y ~ ., data = sim_data)


# Print the fitted model (kernel, cost, number of support vectors, ...).
svm_linear_fit
parsnip model object

Fit time:  560ms 
Support Vector Machine object of class "ksvm" 

SV type: C-svc  (classification) 
 parameter : cost C = 10 

Polynomial kernel function. 
 Hyperparameters : degree =  1  scale =  1  offset =  1 

Number of Support Vectors : 15 

Objective Function Value : -124.6593 
Training error : 0.15 
Probability model included. 
library(kernlab)

# Pull the underlying kernlab object out of the parsnip fit and draw it with
# kernlab's own plot method.
plot(extract_fit_engine(svm_linear_fit))

# Refit the same linear classifier with a much smaller cost (0.1) and
# overwrite the previous fit.
svm_linear_fit <- fit(
  set_args(svm_linear_model, cost = 0.1),
  y ~ .,
  data = sim_data
)


# Print the refitted model for comparison with the cost = 10 fit.
svm_linear_fit
parsnip model object

Fit time:  21ms 
Support Vector Machine object of class "ksvm" 

SV type: C-svc  (classification) 
 parameter : cost C = 0.1 

Polynomial kernel function. 
 Hyperparameters : degree =  1  scale =  1  offset =  1 

Number of Support Vectors : 27 

Objective Function Value : -2.0266 
Training error : 0.125 
Probability model included. 
# Plot the refitted (cost = 0.1) classifier via the underlying kernlab object.
plot(extract_fit_engine(svm_linear_fit))

Now that a smaller value of the cost parameter is being used, we obtain a larger number of support vectors, because the margin is now wider.

# Workflow pairing the formula y ~ . with the linear SVM, leaving `cost`
# open for tuning.
linear_svm_wf <- workflow() %>%
  add_formula(y ~ .) %>%
  add_model(set_args(svm_linear_model, cost = tune()))

# Cross-validation folds stratified by class; seeded so the split is
# reproducible.
set.seed(2021)
sim_data_fold <- sim_data %>% vfold_cv(strata = y)

# Regular grid of ten candidate values over the default cost() range.
param_grid <- cost() %>% grid_regular(levels = 10)

# Evaluate every candidate cost on every fold.
tune_res <- linear_svm_wf %>%
  tune_grid(resamples = sim_data_fold, grid = param_grid)

# Plot the resampled metrics against cost.
tune_res %>% autoplot()

# Resampled accuracy and ROC AUC for every candidate cost.
collect_metrics(tune_res)
# A tibble: 20 x 7
        cost .metric  .estimator  mean     n std_err .config              
       <dbl> <chr>    <chr>      <dbl> <int>   <dbl> <chr>                
 1  0.000977 accuracy binary     0.8      10  0.0624 Preprocessor1_Model01
 2  0.000977 roc_auc  binary     0.925    10  0.0382 Preprocessor1_Model01
 3  0.00310  accuracy binary     0.8      10  0.0624 Preprocessor1_Model02
 4  0.00310  roc_auc  binary     0.925    10  0.0382 Preprocessor1_Model02
 5  0.00984  accuracy binary     0.8      10  0.0624 Preprocessor1_Model03
 6  0.00984  roc_auc  binary     0.925    10  0.0382 Preprocessor1_Model03
 7  0.0312   accuracy binary     0.825    10  0.0534 Preprocessor1_Model04
 8  0.0312   roc_auc  binary     0.925    10  0.0382 Preprocessor1_Model04
 9  0.0992   accuracy binary     0.85     10  0.0408 Preprocessor1_Model05
10  0.0992   roc_auc  binary     0.925    10  0.0382 Preprocessor1_Model05
11  0.315    accuracy binary     0.85     10  0.0553 Preprocessor1_Model06
12  0.315    roc_auc  binary     0.925    10  0.0382 Preprocessor1_Model06
13  1        accuracy binary     0.85     10  0.0553 Preprocessor1_Model07
14  1        roc_auc  binary     0.925    10  0.0382 Preprocessor1_Model07
15  3.17     accuracy binary     0.85     10  0.0553 Preprocessor1_Model08
16  3.17     roc_auc  binary     0.9      10  0.0553 Preprocessor1_Model08
17 10.1      accuracy binary     0.85     10  0.0553 Preprocessor1_Model09
18 10.1      roc_auc  binary     0.9      10  0.0553 Preprocessor1_Model09
19 32        accuracy binary     0.85     10  0.0553 Preprocessor1_Model10
20 32        roc_auc  binary     0.9      10  0.0553 Preprocessor1_Model10
# Top candidates ranked by resampled accuracy.
show_best(tune_res, metric = "accuracy")
# A tibble: 5 x 7
     cost .metric  .estimator  mean     n std_err .config              
    <dbl> <chr>    <chr>      <dbl> <int>   <dbl> <chr>                
1  0.0992 accuracy binary      0.85    10  0.0408 Preprocessor1_Model05
2  0.315  accuracy binary      0.85    10  0.0553 Preprocessor1_Model06
3  1      accuracy binary      0.85    10  0.0553 Preprocessor1_Model07
4  3.17   accuracy binary      0.85    10  0.0553 Preprocessor1_Model08
5 10.1    accuracy binary      0.85    10  0.0553 Preprocessor1_Model09
# Take the cost with the best resampled accuracy, plug it into the workflow,
# and refit on the full training data.
tune_best <- select_best(tune_res, metric = "accuracy")
svm_final_wf <- finalize_workflow(linear_svm_wf, tune_best)

linear_svm_fit <- svm_final_wf %>% fit(data = sim_data)

# Independent test set of 20 observations, generated the same way as the
# training data: class-1 points are moved by +1.5 on both axes.
set.seed(2)
sim_data_test <- tibble(
  x1 = rnorm(20),
  x2 = rnorm(20),
  y  = factor(rep(c(-1, 1), times = 10))
) %>%
  mutate(across(c(x1, x2), ~ ifelse(y == 1, .x + 1.5, .x)))

# Attach predictions to the test set.
pred <- linear_svm_fit %>% augment(sim_data_test)

# Confusion matrix of predicted versus true class.
conf_mat(pred, truth = y, estimate = .pred_class)
          Truth
Prediction -1 1
        -1  7 2
        1   3 8
# Test-set classification accuracy.
accuracy(pred, truth = y, estimate = .pred_class)
# A tibble: 1 x 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy binary          0.75

Support Vector Machine

# Simulate 200 observations with a non-linear class boundary: the first 100
# are centred at (2, 2), the next 50 at (-2, -2) and the last 50 at (0, 0).
# The first 150 form class 1 and the remaining 50 form class 2.
set.seed(1)
sim_data2 <- tibble(
  x1 = rnorm(200) + rep(c(2, -2, 0), times = c(100, 50, 50)),
  x2 = rnorm(200) + rep(c(2, -2, 0), times = c(100, 50, 50)),
  y  = factor(rep(c(1, 2), times = c(150, 50)))
)

# Scatter plot of the simulated points, coloured by class.
ggplot(sim_data2, aes(x = x1, y = x2, color = y)) +
  geom_point()

# Radial-basis-function SVM for classification, run through kernlab with its
# default tuning values.
svm_rbf_spec <- svm_rbf(mode = "classification") %>%
  set_engine("kernlab")

# Fit on the non-linearly separable training data.
svm_rbf_fit <- fit(svm_rbf_spec, y ~ ., data = sim_data2)

# Plot the fitted classifier via the underlying kernlab object.
plot(extract_fit_engine(svm_rbf_fit))

# Test set drawn from the same three-cluster design as sim_data2, using a
# different seed.
set.seed(2)
sim_data2_test <- tibble(
  x1 = rnorm(200) + rep(c(2, -2, 0), times = c(100, 50, 50)),
  x2 = rnorm(200) + rep(c(2, -2, 0), times = c(100, 50, 50)),
  y  = factor(rep(c(1, 2), times = c(150, 50)))
)

# Confusion matrix of the RBF SVM on the test set.
svm_rbf_fit %>%
  augment(new_data = sim_data2_test) %>%
  conf_mat(truth = y, estimate = .pred_class)
          Truth
Prediction   1   2
         1 137   7
         2  13  43
# ROC curve of the RBF SVM on the test set, built from the predicted
# probability of class 1.
# The probability column is passed unnamed through `...`: yardstick's
# probability metrics have no `estimate` argument, and newer yardstick releases
# may reject a named column in that position.
augment(svm_rbf_fit, new_data = sim_data2_test) %>%
  roc_curve(truth = y, .pred_1)
# A tibble: 202 x 3
   .threshold specificity sensitivity
        <dbl>       <dbl>       <dbl>
 1   -Inf            0          1    
 2      0.104        0          1    
 3      0.113        0.02       1    
 4      0.114        0.04       1    
 5      0.115        0.06       1    
 6      0.117        0.08       1    
 7      0.118        0.1        1    
 8      0.119        0.12       1    
 9      0.124        0.14       1    
10      0.124        0.14       0.993
# ... with 192 more rows
# Draw the test-set ROC curve.
# The probability column is passed unnamed through `...`: yardstick's
# probability metrics have no `estimate` argument, and newer yardstick releases
# may reject a named column in that position.
augment(svm_rbf_fit, new_data = sim_data2_test) %>%
  roc_curve(truth = y, .pred_1) %>%
  autoplot()

# Area under the test-set ROC curve.
# The probability column is passed unnamed through `...`: yardstick's
# probability metrics have no `estimate` argument, and newer yardstick releases
# may reject a named column in that position.
augment(svm_rbf_fit, new_data = sim_data2_test) %>%
  roc_auc(truth = y, .pred_1)
# A tibble: 1 x 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 roc_auc binary         0.925

An Introduction to Statistical Learning

ISLR tidymodels Labs

– END

# Record the R version, platform and package versions used to produce the
# results above.
sessionInfo()
R version 4.1.0 (2021-05-18)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19043)

Matrix products: default

locale:
[1] LC_COLLATE=Spanish_Mexico.1252  LC_CTYPE=Spanish_Mexico.1252   
[3] LC_MONETARY=Spanish_Mexico.1252 LC_NUMERIC=C                   
[5] LC_TIME=Spanish_Mexico.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] vctrs_0.3.8        rlang_0.4.11       kernlab_0.9-29     ISLR_1.2          
 [5] yardstick_0.0.8    workflowsets_0.0.2 workflows_0.2.3    tune_0.1.6        
 [9] tidyr_1.1.3        tibble_3.1.3       rsample_0.1.0      recipes_0.1.16    
[13] purrr_0.3.4        parsnip_0.1.7      modeldata_0.1.1    infer_0.5.4       
[17] ggplot2_3.3.5      dplyr_1.0.7        dials_0.0.9        scales_1.1.1      
[21] broom_0.7.8        tidymodels_0.1.3  

loaded via a namespace (and not attached):
 [1] sass_0.4.0         jsonlite_1.7.2     splines_4.1.0      foreach_1.5.1     
 [5] prodlim_2019.11.13 bslib_0.2.5.1      assertthat_0.2.1   highr_0.9         
 [9] GPfit_1.0-8        yaml_2.2.1         globals_0.14.0     ipred_0.9-11      
[13] pillar_1.6.2       backports_1.2.1    lattice_0.20-44    glue_1.4.2        
[17] pROC_1.17.0.1      digest_0.6.27      hardhat_0.1.6      colorspace_2.0-2  
[21] plyr_1.8.6         htmltools_0.5.1.1  Matrix_1.3-3       timeDate_3043.102 
[25] pkgconfig_2.0.3    lhs_1.1.1          DiceDesign_1.9     listenv_0.8.0     
[29] gower_0.2.2        lava_1.6.9         farver_2.1.0       generics_0.1.0    
[33] ellipsis_0.3.2     withr_2.4.2        furrr_0.2.3        nnet_7.3-16       
[37] cli_3.0.0          survival_3.2-11    magrittr_2.0.1     crayon_1.4.1      
[41] evaluate_0.14      future_1.21.0      fansi_0.5.0        parallelly_1.26.1 
[45] MASS_7.3-54        class_7.3-19       prettyunits_1.1.1  tools_4.1.0       
[49] lifecycle_1.0.0    stringr_1.4.0      munsell_0.5.0      compiler_4.1.0    
[53] jquerylib_0.1.4    grid_4.1.0         rstudioapi_0.13    iterators_1.0.13  
[57] labeling_0.4.2     rmarkdown_2.9      gtable_0.3.0       codetools_0.2-18  
[61] DBI_1.1.1          R6_2.5.0           lubridate_1.7.10   knitr_1.33        
[65] utf8_1.2.2         stringi_1.6.2      parallel_4.1.0     Rcpp_1.0.7        
[69] rpart_4.1-15       tidyselect_1.1.1   xfun_0.24