library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.4.1 ──
## ✔ broom        1.0.10     ✔ recipes      1.3.1 
## ✔ dials        1.4.2      ✔ rsample      1.3.1 
## ✔ dplyr        1.1.4      ✔ tailor       0.1.0 
## ✔ ggplot2      4.0.0      ✔ tidyr        1.3.1 
## ✔ infer        1.0.9      ✔ tune         2.0.1 
## ✔ modeldata    1.5.1      ✔ workflows    1.3.0 
## ✔ parsnip      1.3.3      ✔ workflowsets 1.1.1 
## ✔ purrr        1.2.0      ✔ yardstick    1.3.2
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard() masks scales::discard()
## ✖ dplyr::filter()  masks stats::filter()
## ✖ dplyr::lag()     masks stats::lag()
## ✖ recipes::step()  masks stats::step()
# Load the College data shipped with ISLR2 and preview its first six rows to
# see which columns are available.
applications <- ISLR2::College
utils::head(applications, n = 6L)
##                              Private Apps Accept Enroll Top10perc Top25perc
## Abilene Christian University     Yes 1660   1232    721        23        52
## Adelphi University               Yes 2186   1924    512        16        29
## Adrian College                   Yes 1428   1097    336        22        50
## Agnes Scott College              Yes  417    349    137        60        89
## Alaska Pacific University        Yes  193    146     55        16        44
## Albertson College                Yes  587    479    158        38        62
##                              F.Undergrad P.Undergrad Outstate Room.Board Books
## Abilene Christian University        2885         537     7440       3300   450
## Adelphi University                  2683        1227    12280       6450   750
## Adrian College                      1036          99    11250       3750   400
## Agnes Scott College                  510          63    12960       5450   450
## Alaska Pacific University            249         869     7560       4120   800
## Albertson College                    678          41    13500       3335   500
##                              Personal PhD Terminal S.F.Ratio perc.alumni Expend
## Abilene Christian University     2200  70       78      18.1          12   7041
## Adelphi University               1500  29       30      12.2          16  10527
## Adrian College                   1165  53       66      12.9          30   8735
## Agnes Scott College               875  92       97       7.7          37  19016
## Alaska Pacific University        1500  76       72      11.9           2  10922
## Albertson College                 675  67       73       9.4          11   9727
##                              Grad.Rate
## Abilene Christian University        60
## Adelphi University                  56
## Adrian College                      54
## Agnes Scott College                 59
## Alaska Pacific University           15
## Albertson College                   55
# Correlation plot of every column except `Private` (a Yes/No column that
# cannot be passed to cor()).
corrplot::corrplot(cor(dplyr::select(applications, -Private)))

# Model specifications ----

# Unpenalized linear regression; used as the baseline.
lm <- parsnip::linear_reg()

# glmnet fits a whole lambda path in one call and parsnip then reads the
# requested `penalty` off that path. glmnet's automatic path has a lower bound
# (and for ridge it starts at very large lambdas), so every tuned penalty
# below that bound receives the same coefficients and the tuning curve comes
# out flat. `path_values` pins the path to the penalty ranges tuned in this
# script: 100 values with log10(penalty) in [-6, 3] for lasso and [-10, 3] for
# ridge. Keep these in sync with the tuning grids if the grids change.
lasso <- parsnip::linear_reg(mixture = 1.0, penalty = tune()) %>%
  set_engine(
    engine = "glmnet",
    standardize = TRUE,
    path_values = 10^seq(3, -6, length.out = 100)
  )
ridge <- parsnip::linear_reg(mixture = 0.0, penalty = tune()) %>%
  set_engine(
    engine = "glmnet",
    standardize = TRUE,
    path_values = 10^seq(3, -10, length.out = 100)
  )

# Preprocessing ----
# Dummy-encode nominal predictors, then centre and scale every numeric
# predictor so the penalty acts on comparable coefficient scales.
recipe <- recipes::recipe(Apps ~ ., data = applications) %>%
  recipes::step_dummy(all_nominal_predictors()) %>%
  recipes::step_normalize(all_numeric_predictors())

# Workflows ----
# One workflow per model, all sharing the same recipe.
wf_lm <- workflow() %>%
  add_model(lm) %>%
  add_recipe(recipe)

wf_ridge <- workflow() %>%
  add_model(ridge) %>%
  add_recipe(recipe)

wf_lasso <- workflow() %>%
  add_model(lasso) %>%
  add_recipe(recipe)
# Fix the RNG state so the train/test split and the cross-validation folds
# (and therefore every metric derived from them) are identical on each run.
set.seed(2024)

# Hold out a test set, then build cross-validation folds on the training part.
split <- initial_split(applications)
train <- training(split)
test  <- testing(split)
cv <- rsample::vfold_cv(train)

# Cross-validated performance of the unpenalized baseline. Out-of-fold
# predictions are kept (save_pred = TRUE) so they can be inspected later.
cv_results_lm <- fit_resamples(
  wf_lm,
  resamples = cv,
  metrics = yardstick::metric_set(rmse, mape, mae),
  control = control_resamples(save_pred = TRUE)
)
collect_metrics(cv_results_lm)
## # A tibble: 3 × 6
##   .metric .estimator   mean     n std_err .config        
##   <chr>   <chr>       <dbl> <int>   <dbl> <chr>          
## 1 mae     standard    676.     10   33.9  pre0_mod0_post0
## 2 mape    standard     40.8    10    1.80 pre0_mod0_post0
## 3 rmse    standard   1177.     10  105.   pre0_mod0_post0
# Fit the baseline workflow on the whole training set.
full_lm <- fit(wf_lm, data = train)

# Candidate penalties: 100 values per model, evenly spaced on the log10 scale
# (the `range` passed to penalty() is given as log10 exponents).
grid_ridge <- grid_regular(penalty(range = c(-10, 3)), levels = 100)
grid_lasso <- grid_regular(penalty(range = c(-6, 3)), levels = 100)



# Tune ridge and lasso on the SAME folds as the unpenalized baseline. `cv` was
# created right after the train/test split; drawing a fresh set of folds here
# would score the penalized models on different resamples than
# `cv_results_lm`, so differences in RMSE/MAE/MAPE could be fold noise rather
# than a model effect.
cv_results_ridge <- wf_ridge %>%
  tune_grid(
    resamples = cv,
    grid = grid_ridge,
    metrics = yardstick::metric_set(rmse, mape, mae)
  )
metrics_ridge <- collect_metrics(cv_results_ridge)

cv_results_lasso <- wf_lasso %>%
  tune_grid(
    resamples = cv,
    grid = grid_lasso,
    metrics = yardstick::metric_set(rmse, mape, mae)
  )
metrics_lasso <- collect_metrics(cv_results_lasso)

# Ridge penalties with the lowest cross-validated RMSE.
show_best(cv_results_ridge, metric = "rmse")
## # A tibble: 5 × 7
##    penalty .metric .estimator  mean     n std_err .config          
##      <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>            
## 1 1   e-10 rmse    standard   1243.    10    204. pre0_mod001_post0
## 2 1.35e-10 rmse    standard   1243.    10    204. pre0_mod002_post0
## 3 1.83e-10 rmse    standard   1243.    10    204. pre0_mod003_post0
## 4 2.48e-10 rmse    standard   1243.    10    204. pre0_mod004_post0
## 5 3.35e-10 rmse    standard   1243.    10    204. pre0_mod005_post0
# Lasso penalties with the lowest cross-validated RMSE.
cv_results_lasso %>%
  show_best(metric = "rmse")
## # A tibble: 5 × 7
##      penalty .metric .estimator  mean     n std_err .config          
##        <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>            
## 1 0.000001   rmse    standard   1182.    10    114. pre0_mod001_post0
## 2 0.00000123 rmse    standard   1182.    10    114. pre0_mod002_post0
## 3 0.00000152 rmse    standard   1182.    10    114. pre0_mod003_post0
## 4 0.00000187 rmse    standard   1182.    10    114. pre0_mod004_post0
## 5 0.00000231 rmse    standard   1182.    10    114. pre0_mod005_post0
# Per-penalty summary (mean, n, std_err) of every metric for ridge.
cv_results_ridge %>%
  collect_metrics()
## # A tibble: 300 × 7
##     penalty .metric .estimator   mean     n std_err .config          
##       <dbl> <chr>   <chr>       <dbl> <int>   <dbl> <chr>            
##  1 1   e-10 mae     standard    694.     10   46.7  pre0_mod001_post0
##  2 1   e-10 mape    standard     41.3    10    2.06 pre0_mod001_post0
##  3 1   e-10 rmse    standard   1243.     10  204.   pre0_mod001_post0
##  4 1.35e-10 mae     standard    694.     10   46.7  pre0_mod002_post0
##  5 1.35e-10 mape    standard     41.3    10    2.06 pre0_mod002_post0
##  6 1.35e-10 rmse    standard   1243.     10  204.   pre0_mod002_post0
##  7 1.83e-10 mae     standard    694.     10   46.7  pre0_mod003_post0
##  8 1.83e-10 mape    standard     41.3    10    2.06 pre0_mod003_post0
##  9 1.83e-10 rmse    standard   1243.     10  204.   pre0_mod003_post0
## 10 2.48e-10 mae     standard    694.     10   46.7  pre0_mod004_post0
## # ℹ 290 more rows
# Cross-validated RMSE as a function of the ridge penalty. The grid is evenly
# spaced in log10(penalty), so the x axis is log-scaled; on a linear axis
# almost all of the 100 points collapse onto penalty ~ 0 and the shape of the
# curve cannot be read.
collect_metrics(cv_results_ridge) %>%
  filter(.metric == "rmse") %>%
  ggplot(aes(x = penalty, y = mean)) +
  geom_point() +
  geom_line() +
  scale_x_log10() +
  labs(x = "penalty (log10 scale)", y = "mean CV RMSE")

# Print the ridge metric table again for reference next to the plots.
print(collect_metrics(cv_results_ridge))
## # A tibble: 300 × 7
##     penalty .metric .estimator   mean     n std_err .config          
##       <dbl> <chr>   <chr>       <dbl> <int>   <dbl> <chr>            
##  1 1   e-10 mae     standard    694.     10   46.7  pre0_mod001_post0
##  2 1   e-10 mape    standard     41.3    10    2.06 pre0_mod001_post0
##  3 1   e-10 rmse    standard   1243.     10  204.   pre0_mod001_post0
##  4 1.35e-10 mae     standard    694.     10   46.7  pre0_mod002_post0
##  5 1.35e-10 mape    standard     41.3    10    2.06 pre0_mod002_post0
##  6 1.35e-10 rmse    standard   1243.     10  204.   pre0_mod002_post0
##  7 1.83e-10 mae     standard    694.     10   46.7  pre0_mod003_post0
##  8 1.83e-10 mape    standard     41.3    10    2.06 pre0_mod003_post0
##  9 1.83e-10 rmse    standard   1243.     10  204.   pre0_mod003_post0
## 10 2.48e-10 mae     standard    694.     10   46.7  pre0_mod004_post0
## # ℹ 290 more rows
# Ridge RMSE curve restricted to penalties below 250. The penalties are
# log-spaced, so a log x axis is used to keep the small values from piling up
# at the left edge of the plot.
collect_metrics(cv_results_ridge) %>%
  filter(.metric == "rmse") %>%
  filter(penalty < 250) %>%
  ggplot(aes(x = penalty, y = mean)) +
  geom_point() +
  geom_line() +
  scale_x_log10() +
  labs(x = "penalty (log10 scale)", y = "mean CV RMSE")

# Lasso RMSE curve restricted to penalties below 50. The penalties are
# log-spaced, so a log x axis is used to keep the small values from piling up
# at the left edge of the plot.
collect_metrics(cv_results_lasso) %>%
  filter(.metric == "rmse") %>%
  filter(penalty < 50) %>%
  ggplot(aes(x = penalty, y = mean)) +
  geom_point() +
  geom_line() +
  scale_x_log10() +
  labs(x = "penalty (log10 scale)", y = "mean CV RMSE")

# Single ridge penalty with the lowest cross-validated RMSE.
cv_results_ridge %>%
  select_best(metric = "rmse")
## # A tibble: 1 × 2
##        penalty .config          
##          <dbl> <chr>            
## 1 0.0000000001 pre0_mod001_post0
# Refit ridge on the full training set at the penalty with the lowest CV RMSE.
# The fit gets its own name so that the lasso refit further down, which
# assigns `best_wf`, does not overwrite it and both models stay available.
best_wf_ridge <- wf_ridge %>%
  finalize_workflow(select_best(cv_results_ridge, metric = "rmse")) %>%
  fit(train)

# Coefficients of the final ridge model (on the normalized predictor scale).
best_wf_ridge %>%
  extract_fit_parsnip() %>%
  tidy()
## 
## Attaching package: 'Matrix'
## The following objects are masked from 'package:tidyr':
## 
##     expand, pack, unpack
## Loaded glmnet 4.1-10
## # A tibble: 18 × 3
##    term        estimate      penalty
##    <chr>          <dbl>        <dbl>
##  1 (Intercept)  3109.   0.0000000001
##  2 Accept       2636.   0.0000000001
##  3 Enroll        337.   0.0000000001
##  4 Top10perc     557.   0.0000000001
##  5 Top25perc      -9.16 0.0000000001
##  6 F.Undergrad   358.   0.0000000001
##  7 P.Undergrad    23.9  0.0000000001
##  8 Outstate     -119.   0.0000000001
##  9 Room.Board    243.   0.0000000001
## 10 Books           3.75 0.0000000001
## 11 Personal        4.74 0.0000000001
## 12 PhD           -88.1  0.0000000001
## 13 Terminal      -72.0  0.0000000001
## 14 S.F.Ratio      75.1  0.0000000001
## 15 perc.alumni   -79.4  0.0000000001
## 16 Expend        414.   0.0000000001
## 17 Grad.Rate     201.   0.0000000001
## 18 Private_Yes  -263.   0.0000000001
# Refit lasso on the full training set at the penalty with the lowest CV RMSE,
# then list its coefficients.
best_wf <- fit(
  finalize_workflow(wf_lasso, select_best(cv_results_lasso, metric = "rmse")),
  train
)
tidy(extract_fit_parsnip(best_wf))
## # A tibble: 18 × 3
##    term        estimate  penalty
##    <chr>          <dbl>    <dbl>
##  1 (Intercept)   3109.  0.000001
##  2 Accept        4083.  0.000001
##  3 Enroll       -1080.  0.000001
##  4 Top10perc     1086.  0.000001
##  5 Top25perc     -392.  0.000001
##  6 F.Undergrad    484.  0.000001
##  7 P.Undergrad     61.6 0.000001
##  8 Outstate      -379.  0.000001
##  9 Room.Board     166.  0.000001
## 10 Books          -17.1 0.000001
## 11 Personal        11.9 0.000001
## 12 PhD           -161.  0.000001
## 13 Terminal       -34.8 0.000001
## 14 S.F.Ratio       77.9 0.000001
## 15 perc.alumni     22.3 0.000001
## 16 Expend         418.  0.000001
## 17 Grad.Rate      161.  0.000001
## 18 Private_Yes   -223.  0.000001
# Coefficient table (estimate, std.error, statistic, p.value) of the
# unpenalized fit.
tidy(extract_fit_parsnip(full_lm))
## # A tibble: 18 × 5
##    term        estimate std.error statistic   p.value
##    <chr>          <dbl>     <dbl>     <dbl>     <dbl>
##  1 (Intercept)   3109.       45.1    68.9   7.11e-277
##  2 Accept        4123.      116.     35.5   1.16e-145
##  3 Enroll       -1205.      217.     -5.55  4.36e-  8
##  4 Top10perc     1121.      123.      9.11  1.46e- 18
##  5 Top25perc     -425.      110.     -3.84  1.35e-  4
##  6 F.Undergrad    573.      198.      2.90  3.89e-  3
##  7 P.Undergrad     61.9      58.3     1.06  2.89e-  1
##  8 Outstate      -392.       92.3    -4.24  2.56e-  5
##  9 Room.Board     168.       65.2     2.58  1.01e-  2
## 10 Books          -20.9      48.3    -0.433 6.65e-  1
## 11 Personal        13.4      52.4     0.255 7.99e-  1
## 12 PhD           -165.       94.5    -1.74  8.16e-  2
## 13 Terminal       -36.0      92.8    -0.388 6.98e-  1
## 14 S.F.Ratio       80.9      63.2     1.28  2.01e-  1
## 15 perc.alumni     31.3      61.9     0.505 6.13e-  1
## 16 Expend         422.       80.6     5.24  2.29e-  7
## 17 Grad.Rate      168.       61.1     2.76  6.03e-  3
## 18 Private_Yes   -224.       75.9    -2.95  3.28e-  3
# Classic summary() of the underlying engine fit: residuals, coefficient
# table, R-squared and F-statistic.
full_lm %>%
  extract_fit_parsnip() %>%
  extract_fit_engine() %>%
  summary()
## 
## Call:
## stats::lm(formula = ..y ~ ., data = data)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -4901.3  -446.8    -2.0   353.4  7284.2 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)    
## (Intercept)  3109.14      45.11  68.917  < 2e-16 ***
## Accept       4122.67     116.29  35.453  < 2e-16 ***
## Enroll      -1205.09     217.06  -5.552 4.36e-08 ***
## Top10perc    1120.57     123.04   9.107  < 2e-16 ***
## Top25perc    -424.60     110.44  -3.845 0.000135 ***
## F.Undergrad   573.05     197.70   2.899 0.003893 ** 
## P.Undergrad    61.87      58.33   1.061 0.289295    
## Outstate     -391.69      92.29  -4.244 2.56e-05 ***
## Room.Board    168.30      65.22   2.580 0.010122 *  
## Books         -20.89      48.26  -0.433 0.665239    
## Personal       13.38      52.40   0.255 0.798562    
## PhD          -164.81      94.46  -1.745 0.081558 .  
## Terminal      -36.01      92.78  -0.388 0.698038    
## S.F.Ratio      80.91      63.24   1.279 0.201307    
## perc.alumni    31.26      61.86   0.505 0.613472    
## Expend        422.29      80.61   5.239 2.29e-07 ***
## Grad.Rate     168.31      61.06   2.756 0.006035 ** 
## Private_Yes  -224.19      75.91  -2.953 0.003275 ** 
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 1088 on 564 degrees of freedom
## Multiple R-squared:  0.9302, Adjusted R-squared:  0.9281 
## F-statistic: 442.2 on 17 and 564 DF,  p-value: < 2.2e-16

Lasso: logra un buen RMSE, descartando 3 variables. Ridge: no logro que la curva de RMSE en función de lambda tenga forma de U; ¿qué característica del dataset debería estar mirando?